# Packages required by this script. Any that are missing are installed, then
# all of them are attached with library().
required_packages <- c(
  "tidyverse",    # Data manipulation and visualization
  "jtools",       # summ() regression summaries
                  # NOTE(review): summ() is not called in the code shown here;
                  # confirm jtools is still needed
  "broom",        # tidy() model output formatting
  "knitr",        # kable() table formatting
  "kableExtra"    # kable_styling() / footnote() table enhancements
)

# Install only the packages that are not already present in an R library
new_packages <- setdiff(required_packages, rownames(installed.packages()))
if (length(new_packages) > 0) {
  install.packages(new_packages)
}

# Attach every package. invisible() keeps the list returned by lapply() from
# being auto-printed when the script runs at top level (e.g. via Rscript).
invisible(lapply(required_packages, library, character.only = TRUE))

# Read the merged survey file and keep only respondents who have both a
# recalled vote and a propensity-to-vote (PTV) score for SALF
data <- read_csv("merged_survey_data.csv") %>%
  drop_na(vote_recall, ptv_salf)

# Plot colour for each recalled party; only parties listed here are analysed
party_colors <- c(
  PSOE    = "#E41A1C",  # red
  Sumar   = "#FF69B4",  # hot pink
  Podemos = "#984EA3",  # purple
  VOX     = "#4DAF4A",  # green
  PP      = "#377EB8",  # blue
  SALF    = "#A6761D"   # brown
)

# Display names for survey identifiers and for the pooled model
survey_labels <- c(
  db40_july   = "40dB July",
  db40_august = "40dB August",
  Overall     = "Overall"
)

# Point shape (R pch code) for each survey / the pooled model
survey_shapes <- c(
  db40_july   = 0,   # open square
  db40_august = 5,   # open diamond
  Overall     = 16   # filled circle
)

# Distinct survey identifiers in the data, then only the 40dB waves
# (identifiers containing "db40_")
unique_surveys <- unique(data$survey)
relevant_surveys <- Filter(function(s) grepl("db40_", s), unique_surveys)

# Make sure the output folders exist. Parents are listed before children, and
# folders that already exist are left alone without a warning.
invisible(vapply(
  c("outputs", "outputs/tables", "outputs/figures"),
  dir.create,
  logical(1),
  showWarnings = FALSE
))

# Keep respondents whose recalled party has an entry in party_colors and
# whose survey is one of the relevant 40dB surveys
filtered_data <- filter(
  data,
  vote_recall %in% names(party_colors) & survey %in% relevant_surveys
)

# Per-survey OLS models of SALF PTV on recalled vote plus demographic
# controls, stored in a list named by survey identifier
reg_models <- lapply(
  setNames(as.list(relevant_surveys), relevant_surveys),
  function(surv) {
    surv_data <- filter(filtered_data, survey == surv)
    lm(ptv_salf ~ vote_recall + age + gender + edu_fct, data = surv_data)
  }
)

# Pooled model across all relevant surveys, with survey as a control
reg_models[["Overall"]] <- lm(
  ptv_salf ~ vote_recall + age + gender + edu_fct + survey,
  data = filtered_data
)

# Predicted PTV for SALF by recalled vote ----------------------------------
# Every model is evaluated at one shared reference respondent: the pooled
# mean age and the most frequent gender and education category in
# filtered_data.

# Most frequent non-missing value of `x`, returned with x's own type
# (character, factor, numeric, ...) so that predict() accepts it whatever type
# the column had when the model was fitted. Ties resolve as in
# which.max(table(x)): the first value in table() order wins.
modal_value <- function(x) {
  counts <- table(x)
  x[match(names(counts)[which.max(counts)], as.character(x))]
}

# Reference covariate values; they are identical for every model, so they
# are computed once
ref_age <- mean(filtered_data$age, na.rm = TRUE)
ref_gender <- modal_value(filtered_data$gender)
ref_edu <- modal_value(filtered_data$edu_fct)

model_names <- names(reg_models)
all_predictions <- lapply(setNames(model_names, model_names), function(model_name) {
  model <- reg_models[[model_name]]

  # Predict only for parties the model saw during fitting. A party with no
  # respondents in one survey would otherwise make predict() stop with a
  # "factor vote_recall has new levels" error.
  fitted_parties <- model$xlevels$vote_recall
  if (is.null(fitted_parties)) {
    fitted_parties <- names(party_colors)
  }

  new_data <- tibble(
    vote_recall = intersect(names(party_colors), fitted_parties),
    age = ref_age,
    gender = ref_gender,
    edu_fct = ref_edu
  )

  # The pooled model also needs a survey value. Predictions use the first
  # relevant survey, so "Overall" estimates are conditional on that survey
  # rather than averaged over surveys.
  if (model_name == "Overall") {
    new_data$survey <- relevant_surveys[1]
  }

  preds <- predict(model, newdata = new_data, se.fit = TRUE)

  new_data %>%
    mutate(
      predicted_ptv = preds$fit,
      se = preds$se.fit,
      model = model_name
    )
})

# Stack the per-model predictions into one data frame
plot_data <- bind_rows(all_predictions)

# Sample sizes for the plot ------------------------------------------------
# Respondents per survey in filtered_data, plus an "Overall" row with the sum
survey_totals <- count(filtered_data, survey, name = "total_sample_size")
survey_totals <- bind_rows(
  survey_totals,
  data.frame(
    survey = "Overall",
    total_sample_size = sum(survey_totals$total_sample_size)
  )
)

# Respondents per survey x recalled party, and per party over all surveys
party_sizes <- count(filtered_data, survey, vote_recall, name = "sample_size")
overall_party_sizes <- mutate(
  count(filtered_data, vote_recall, name = "sample_size"),
  survey = "Overall"
)
all_sample_sizes <- bind_rows(party_sizes, overall_party_sizes)

# Attach both counts to the predictions (plot_data$model holds the survey id)
plot_data <- plot_data %>%
  left_join(all_sample_sizes, by = c("model" = "survey", "vote_recall")) %>%
  left_join(survey_totals, by = c("model" = "survey"))

# Legend order: pooled model first, then the individual surveys
model_order <- c("Overall", relevant_surveys)

# Parties sorted from lowest to highest pooled-model prediction
party_order <- plot_data %>%
  filter(model == "Overall") %>%
  arrange(predicted_ptv) %>%
  pull(vote_recall)

# Apply both orderings as factor levels
plot_data <- plot_data %>%
  mutate(
    model = factor(model, levels = model_order),
    vote_recall = factor(vote_recall, levels = party_order)
  )

# One row per party x model, holding the values the plot layers read.
# Groups contain a single prediction each, so the means carry values through.
mean_ptv <- plot_data %>%
  group_by(vote_recall, model) %>%
  summarise(
    mean_ptv = mean(predicted_ptv, na.rm = TRUE),
    se = mean(se, na.rm = TRUE),
    sample_size = first(sample_size),
    total_sample_size = first(total_sample_size),
    .groups = "drop"
  ) %>%
  mutate(mean_label = round(mean_ptv, 1))

# Predicted-PTV plot -------------------------------------------------------
# One row per recalled party. Each survey model gets its own marker shape
# (survey_shapes); the pooled "Overall" model is a filled circle. Marker size
# encodes total_sample_size. Horizontal bars span predicted PTV +/- 1
# standard error (not a confidence interval).
p <- ggplot() +
  # Dotted reference line at x = 5
  geom_vline(xintercept = 5, linetype = "dotted", color = "gray") +
  # +/- 1 SE bars. Line thickness is set with `linewidth`: since ggplot2 3.4.0
  # the `size` aesthetic is deprecated for lines and emits a warning.
  geom_linerange(
    data = mean_ptv,
    aes(
      y = vote_recall,
      xmin = mean_ptv - se,
      xmax = mean_ptv + se,
      color = vote_recall
    ),
    linewidth = 0.5,
    alpha = 0.7
  ) +
  # Per-survey estimates, with shape mapped to the survey
  geom_point(
    data = mean_ptv %>% filter(model != "Overall"),
    aes(
      x = mean_ptv, y = vote_recall, color = vote_recall,
      size = total_sample_size, shape = model
    ),
    alpha = 0.7
  ) +
  # Pooled estimate as a filled circle (shape 16)
  geom_point(
    data = mean_ptv %>% filter(model == "Overall"),
    aes(
      x = mean_ptv, y = vote_recall, color = vote_recall,
      size = total_sample_size
    ),
    shape = 16,
    alpha = 0.7
  ) +
  # Colour only distinguishes parties, which are already labelled on the
  # y axis, so no colour legend is drawn
  scale_color_manual(values = party_colors, guide = "none") +
  scale_shape_manual(values = survey_shapes, name = "Survey", labels = survey_labels) +
  scale_size_continuous(
    name = "Survey Sample Size (N)",
    breaks = c(500, 1000, 2000),
    labels = scales::comma(c(500, 1000, 2000)),
    range = c(3, 6)
  ) +
  theme_minimal() +
  theme(
    panel.grid.minor = element_blank(),
    legend.position = "right",
    plot.title = element_text(hjust = 0.5, size = 14)
  ) +
  labs(
    title = "Propensity to Vote for SALF",
    x = "Predicted Values",
    y = "",
    caption = "Note: The independent variable is vote recall at the EP 2024 Elections"
  )

# Save the plot on a white background
ggsave("outputs/figures/ptv_salf.png", p, width = 8, height = 6, bg = "white")

#==========================================================================================
# MODEL FIT STATISTICS AND COEFFICIENT TABLES IN LATEX
#==========================================================================================

# Summarise the overall fit of a linear model as a one-row tibble.
#
# model: a fitted lm object
#
# Returns columns R_squared, Adj_R_squared, F_statistic, df1 (numerator df),
# df2 (denominator df) and p_value (upper tail of the F distribution), all as
# plain unnamed numbers. For models without an F statistic (e.g. an
# intercept-only fit, where summary() reports none) the F-related columns
# are NA instead of causing an error. The sample size is deliberately not
# included; callers add N themselves.
get_model_fit <- function(model) {
  # summary.lm() is comparatively expensive; compute it once
  fit_summary <- summary(model)
  f_stat <- fit_summary$fstatistic
  if (is.null(f_stat)) {
    f_stat <- c(value = NA_real_, numdf = NA_real_, dendf = NA_real_)
  }

  tibble(
    R_squared = fit_summary$r.squared,
    Adj_R_squared = fit_summary$adj.r.squared,
    # unname(): fstatistic is a named vector ("value", "numdf", "dendf")
    F_statistic = unname(f_stat["value"]),
    df1 = unname(f_stat["numdf"]),
    df2 = unname(f_stat["dendf"]),
    p_value = pf(F_statistic, df1, df2, lower.tail = FALSE)
  )
}

# Model fit table: one row per model in reg_models -------------------------
# Rows are built independently and bound once, instead of growing a data
# frame with rbind() inside a loop.
model_fit_table <- bind_rows(lapply(names(reg_models), function(model_name) {
  model <- reg_models[[model_name]]
  fit_stats <- get_model_fit(model)

  data.frame(
    Survey = model_name,
    # Observations the model actually used. Under lm()'s default na.action,
    # rows with missing age/gender/edu_fct are dropped, so counting survey
    # rows in filtered_data would overstate N.
    N = nobs(model),
    # unname(): summary.lm()'s F statistic is a named vector. A named scalar
    # here would become this row's row name ("value", "value1", ...), which
    # kable() would then print as an extra column.
    R_squared = unname(fit_stats$R_squared),
    Adj_R_squared = unname(fit_stats$Adj_R_squared),
    F_statistic = unname(fit_stats$F_statistic),
    df1 = unname(fit_stats$df1),
    df2 = unname(fit_stats$df2),
    p_value = unname(fit_stats$p_value),
    stringsAsFactors = FALSE
  )
}))

# Reset to default row names so kable() never adds a row-name column
rownames(model_fit_table) <- NULL

# Show display labels (e.g. "40dB July") in the Survey column; any name that
# has no entry in survey_labels is kept unchanged
model_fit_table <- model_fit_table %>%
  mutate(Survey = coalesce(unname(survey_labels[Survey]), Survey))

# LaTeX table of model fit statistics. Because escape = FALSE, the header can
# use math mode ($R^2$) rather than the non-ASCII "²", which pdfLaTeX setups
# without a matching Unicode definition reject. row.names = FALSE stops
# kable() from adding a row-name column if model_fit_table carries
# non-default row names.
latex_model_fit <- kable(
  model_fit_table,
  format = "latex",
  booktabs = TRUE,
  caption = "Model Fit Statistics for Regression Models by Survey",
  label = "tab:model_fit_ptv_salf",
  col.names = c("Survey", "N", "$R^2$", "Adjusted $R^2$", "F-statistic",
                "df1", "df2", "p-value"),
  digits = c(0, 0, 3, 3, 2, 0, 0, 4),
  row.names = FALSE,
  escape = FALSE
) %>%
  kable_styling(latex_options = c("striped", "hold_position")) %>%
  footnote(
    general = "All models include vote recall, gender, age, and education as predictors. Overall model also includes survey as a control.",
    threeparttable = TRUE
  )

# Write to file
writeLines(latex_model_fit, "outputs/tables/model_fit_ptv_salf.tex")

# Tidy one model's coefficients and add display columns.
#
# model:      a fitted model that broom::tidy() understands (an lm here)
# model_name: label stored in the new `model` column
#
# Returns broom::tidy(model) with three extra columns:
#   stars     - significance code: "***" p < 0.001, "**" p < 0.01,
#               "*" p < 0.05, "." p < 0.1, "" otherwise or when p is NA
#   formatted - "<estimate><stars>\n(<std.error>)", both to 3 decimals
#   model     - model_name
#
# NOTE(review): the "\n" is a literal newline. Inside a LaTeX tabular cell it
# renders as a space, so estimate and SE appear on one line unless the cells
# are wrapped (e.g. with kableExtra::linebreak()).
extract_coefficients <- function(model, model_name) {
  star_codes <- c("***", "**", "*", ".", "")
  p_cuts <- c(0.001, 0.01, 0.05, 0.1)

  broom::tidy(model) %>%
    mutate(
      # findInterval() is 0 below 0.001, rising to 4 at or above 0.1
      stars = star_codes[findInterval(p.value, p_cuts) + 1L],
      stars = replace(stars, is.na(stars), ""),
      formatted = sprintf("%.3f%s\n(%.3f)", estimate, stars, std.error),
      model = model_name
    )
}

# Tidy coefficients from every model stacked into one data frame; the `model`
# column records which model each row came from
all_coefs <- bind_rows(Map(extract_coefficients, reg_models, names(reg_models)))

# Party coefficient table --------------------------------------------------
# Vote-recall dummy coefficients only: one row per party, one column per
# model. The baseline party of the dummy coding has no row.
party_coefs <- all_coefs %>%
  filter(grepl("vote_recall", term)) %>%
  # Turn terms such as "vote_recallPSOE" into the bare party name
  mutate(
    party = gsub("vote_recall", "", term),
    party = gsub("^[^A-Za-z]+", "", party)
  ) %>%
  select(model, party, formatted) %>%
  pivot_wider(
    id_cols = party,
    names_from = model,
    values_from = formatted
  ) %>%
  # pivot_wider() orders columns by first appearance in all_coefs (per-survey
  # models first, "Overall" last). Put them in model_order so the table
  # matches the plot legend order.
  select(party, any_of(model_order))

# Rename model columns to their display labels
party_coefs_renamed <- party_coefs %>%
  rename_with(~ ifelse(.x %in% names(survey_labels), survey_labels[.x], .x))

# LaTeX table. Headers are taken from the data's own (relabelled) column
# names, so they cannot drift out of sync with the column order.
latex_party_coefs <- kable(
  party_coefs_renamed,
  format = "latex",
  booktabs = TRUE,
  caption = "Vote Recall Coefficients for PTV SALF Models",
  label = "tab:coef_ptv_salf",
  col.names = c("Party", names(party_coefs_renamed)[-1]),
  align = c("l", rep("c", ncol(party_coefs_renamed) - 1)),
  escape = FALSE
) %>%
  kable_styling(latex_options = c("striped", "hold_position")) %>%
  footnote(
    general = "Coefficients shown with standard errors in parentheses.",
    symbol = c("p < 0.1", "p < 0.05", "p < 0.01", "p < 0.001"),
    symbol_manual = c(".", "*", "**", "***"),
    threeparttable = TRUE
  )

# Write to file
writeLines(latex_party_coefs, "outputs/tables/party_coef_ptv_salf.tex")

# Full coefficient table ---------------------------------------------------
# Every term from every model, one column per model.
full_coefs <- all_coefs %>%
  mutate(
    # Label the intercept first. The leading-character strip below would
    # otherwise turn "(Intercept)" into "Intercept)" before it was matched.
    term = ifelse(term == "(Intercept)", "Intercept", term),
    term = gsub("vote_recall", "", term),
    term = gsub("^[^A-Za-z]+", "", term),
    # The table is written with escape = FALSE, so LaTeX special characters
    # in term names (e.g. the "_" in edu_fct... or surveydb40_...) must be
    # escaped by hand or the .tex file will not compile.
    term = gsub("([_%&#$])", "\\\\\\1", term)
  ) %>%
  select(term, model, formatted) %>%
  pivot_wider(
    id_cols = term,
    names_from = model,
    values_from = formatted
  ) %>%
  # pivot_wider() orders model columns by first appearance; use model_order
  select(term, any_of(model_order))

# Rename model columns to their display labels
full_coefs_renamed <- full_coefs %>%
  rename_with(~ ifelse(.x %in% names(survey_labels), survey_labels[.x], .x))

# LaTeX table. Headers come from the data's own (relabelled) column names so
# they always match the column order.
latex_full_coefs <- kable(
  full_coefs_renamed,
  format = "latex",
  booktabs = TRUE,
  caption = "Full Regression Coefficients for PTV SALF Models",
  label = "tab:full_coef_ptv_salf",
  col.names = c("Term", names(full_coefs_renamed)[-1]),
  align = c("l", rep("c", ncol(full_coefs_renamed) - 1)),
  escape = FALSE
) %>%
  kable_styling(latex_options = c("striped", "scale_down", "hold_position")) %>%
  footnote(
    general = "Coefficients shown with standard errors in parentheses.",
    symbol = c("p < 0.1", "p < 0.05", "p < 0.01", "p < 0.001"),
    symbol_manual = c(".", "*", "**", "***"),
    threeparttable = TRUE
  )

# Write to file
writeLines(latex_full_coefs, "outputs/tables/full_coef_ptv_salf.tex")

# Report completion on stdout; the trailing newline keeps the next shell
# prompt on its own line
cat("Analysis completed. Output files saved.\n")